{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "SpeechDenoiserCNN.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true,
      "machine_shape": "hm"
    },
    "kernelspec": {
      "display_name": "Python [conda env:tf2]",
      "language": "python",
      "name": "conda-env-tf2-py"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.12"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "VoyZ5yLRaAeI"
      },
      "source": [
        "# Abort early unless TensorFlow reports a GPU at '/device:GPU:0'.\n",
        "import tensorflow as tf\n",
        "device_name = tf.test.gpu_device_name()\n",
        "if device_name != '/device:GPU:0':\n",
        "  raise SystemError('GPU device not found')\n",
        "print('Found GPU at: {}'.format(device_name))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Ah7ISzvtEmur"
      },
      "source": [
        "import librosa\n",
        "import pandas as pd\n",
        "import os\n",
        "import datetime\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import IPython.display as ipd\n",
        "import librosa.display\n",
        "import scipy\n",
        "import glob\n",
        "import numpy as np\n",
        "import math\n",
        "import warnings\n",
        "import pickle\n",
        "from sklearn.utils import shuffle\n",
        "import zipfile\n",
        "# Load the TensorBoard notebook extension.\n",
        "%load_ext tensorboard"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WBZyUMpobR-Z"
      },
      "source": [
        "# List every local device that TensorFlow can see.\n",
        "from tensorflow.python.client import device_lib\n",
        "device_lib.list_local_devices()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zu2lgR8lEmu0"
      },
      "source": [
        "# Seed TensorFlow and NumPy with the same fixed value.\n",
        "tf.random.set_seed(999)\n",
        "np.random.seed(999)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JGbkvQHxxrj_"
      },
      "source": [
        "# -nc (no-clobber): skip the download when dataset.zip is already present, so that\n",
        "# re-running this cell does not fetch a second copy.\n",
        "!wget -nc 'cdn.daitan.com/dataset.zip'"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "51o6Ze9hxrkE"
      },
      "source": [
        "# Unpack the downloaded archive into ./dataset.\n",
        "dataset_file_name = './dataset.zip'\n",
        "with zipfile.ZipFile(dataset_file_name, 'r') as zip_ref:\n",
        "    zip_ref.extractall('./dataset')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "d43UvhEqxrkJ"
      },
      "source": [
        "# Folder searched for the train_* / val_* TFRecord files.\n",
        "path_to_dataset = \"./dataset/tfrecords\""
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "W36WF4mT1kGK"
      },
      "source": [
        "# Collect the training and validation TFRecord file names.\n",
        "# glob returns matches in a filesystem-dependent order, so sort them first;\n",
        "# otherwise the seeded shuffle below is not reproducible across machines.\n",
        "train_tfrecords_filenames = sorted(glob.glob(os.path.join(path_to_dataset, 'train_*')))\n",
        "val_tfrecords_filenames = sorted(glob.glob(os.path.join(path_to_dataset, 'val_*')))\n",
        "\n",
        "# Shuffle the training file names in place.\n",
        "np.random.shuffle(train_tfrecords_filenames)\n",
        "print(\"Training file names: \", train_tfrecords_filenames)\n",
        "print(\"Validation file names: \", val_tfrecords_filenames)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jKZtPoLMEmvJ"
      },
      "source": [
        "# STFT and feature-shape constants shared by the rest of the notebook.\n",
        "windowLength = 256\n",
        "overlap      = round(0.25 * windowLength) # hop of 25% of the window, i.e. 75% overlap between frames\n",
        "ffTLength    = windowLength\n",
        "inputFs      = 48e3\n",
        "fs           = 16e3\n",
        "numFeatures  = ffTLength//2 + 1\n",
        "numSegments  = 8\n",
        "print(\"windowLength:\",windowLength)\n",
        "print(\"overlap:\",overlap)\n",
        "print(\"ffTLength:\",ffTLength)\n",
        "print(\"inputFs:\",inputFs)\n",
        "print(\"fs:\",fs)\n",
        "print(\"numFeatures:\",numFeatures)\n",
        "print(\"numSegments:\",numSegments)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JXpSGe8L0cl_"
      },
      "source": [
        "# Base folders of the speech and noise audio files that are read further down.\n",
        "mozilla_basepath = \"./dataset/en\"\n",
        "UrbanSound8K_basepath = './dataset/UrbanSound8K'"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zzbdfIi-Lgk9"
      },
      "source": [
        "## Prepare Input features"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "24kph3wJ3s5V"
      },
      "source": [
        "def tf_record_parser(record):\n",
        "    \"\"\"Decode one serialized example into a (noisy magnitude, clean magnitude) pair.\n",
        "\n",
        "    All three features are stored as raw float32 bytes. The noisy phase is\n",
        "    decoded and reshaped as well, but it is not part of the returned tuple.\n",
        "    \"\"\"\n",
        "    feature_spec = {\n",
        "        \"noise_stft_phase\": tf.io.FixedLenFeature((), tf.string, default_value=\"\"),\n",
        "        'noise_stft_mag_features': tf.io.FixedLenFeature([], tf.string),\n",
        "        \"clean_stft_magnitude\": tf.io.FixedLenFeature((), tf.string)\n",
        "    }\n",
        "    parsed = tf.io.parse_single_example(record, feature_spec)\n",
        "\n",
        "    def decode(key, shape):\n",
        "        # Raw bytes -> flat float32 tensor -> fixed shape, named after the feature key.\n",
        "        flat = tf.io.decode_raw(parsed[key], tf.float32)\n",
        "        return tf.reshape(flat, shape, name=key)\n",
        "\n",
        "    noisy_magnitude = decode('noise_stft_mag_features', (129, 8, 1))\n",
        "    clean_magnitude = decode('clean_stft_magnitude', (129, 1, 1))\n",
        "    decode('noise_stft_phase', (129,))  # decoded like the others, but not returned\n",
        "\n",
        "    return noisy_magnitude, clean_magnitude"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uJXPGrgVTCbZ"
      },
      "source": [
        "## Create tf.data.Dataset"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FF_A3YbZTCsj"
      },
      "source": [
        "# Training pipeline: parse -> shuffle (8192-element buffer) -> repeat forever -> batch -> prefetch.\n",
        "train_dataset = (\n",
        "    tf.data.TFRecordDataset([train_tfrecords_filenames])\n",
        "    .map(tf_record_parser)\n",
        "    .shuffle(8192)\n",
        "    .repeat()\n",
        "    .batch(512)\n",
        "    .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)\n",
        ")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NOuWPzQaTNDy"
      },
      "source": [
        "# Validation pipeline: parse -> single pass -> batch (no shuffling).\n",
        "test_dataset = (\n",
        "    tf.data.TFRecordDataset([val_tfrecords_filenames])\n",
        "    .map(tf_record_parser)\n",
        "    .repeat(1)\n",
        "    .batch(512)\n",
        ")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CEG5JofLSfRw"
      },
      "source": [
        "## Model Training"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2OZFegXrSYle"
      },
      "source": [
        "from tensorflow.keras.layers import Conv2D, Input, LeakyReLU, Flatten, Dense, Reshape, Conv2DTranspose, BatchNormalization, Activation\n",
        "from tensorflow.keras import Model, Sequential"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "lc2K0cfhPU5-"
      },
      "source": [
        "def conv_block(x, filters, kernel_size, strides, padding='same', use_bn=True):\n",
        "  \"\"\"Bias-free Conv2D (L2 factor 0.0006), then ReLU, then optional batch normalisation.\"\"\"\n",
        "  conv = Conv2D(\n",
        "      filters=filters,\n",
        "      kernel_size=kernel_size,\n",
        "      strides=strides,\n",
        "      padding=padding,\n",
        "      use_bias=False,\n",
        "      kernel_regularizer=tf.keras.regularizers.l2(0.0006),\n",
        "  )\n",
        "  activated = Activation('relu')(conv(x))\n",
        "  if not use_bn:\n",
        "    return activated\n",
        "  return BatchNormalization()(activated)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XFYyKoAVQYCz"
      },
      "source": [
        "def full_pre_activation_block(x, filters, kernel_size, strides, padding='same', use_bn=True):\n",
        "  \"\"\"Residual block in (BN ->) ReLU -> Conv order, applied twice, plus an identity shortcut.\n",
        "\n",
        "  Args:\n",
        "    x: input tensor; the second convolution restores its channel count.\n",
        "    filters: number of filters of the first convolution.\n",
        "    kernel_size: kernel size forwarded to both convolutions.\n",
        "    strides: strides forwarded to both convolutions. The shortcut is added\n",
        "      unchanged, so the branch output must keep the shape of `x`.\n",
        "    padding: accepted for signature parity with conv_block; both\n",
        "      convolutions always use 'same' padding.\n",
        "    use_bn: when False the two BatchNormalization layers are skipped\n",
        "      (this flag used to be silently ignored).\n",
        "\n",
        "  Returns:\n",
        "    The sum of the input and the residual branch output.\n",
        "  \"\"\"\n",
        "  shortcut = x\n",
        "  in_channels = x.shape[-1]\n",
        "\n",
        "  if use_bn:\n",
        "    x = BatchNormalization()(x)\n",
        "  x = Activation('relu')(x)\n",
        "  x = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides, padding='same')(x)\n",
        "\n",
        "  if use_bn:\n",
        "    x = BatchNormalization()(x)\n",
        "  x = Activation('relu')(x)\n",
        "  x = Conv2D(filters=in_channels, kernel_size=kernel_size, strides=strides, padding='same')(x)\n",
        "\n",
        "  return shortcut + x"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NlRQ3ngpZG1Y"
      },
      "source": [
        "def build_model(l2_strength):\n",
        "  \"\"\"Build and compile the convolutional denoiser.\n",
        "\n",
        "  The network is five stages of three bias-free convolutions (18, 30 and 8\n",
        "  filters), each followed by ReLU and batch normalisation. The 30-filter\n",
        "  outputs of stages 1 and 2 are added back into stages 5 and 4 respectively.\n",
        "  Spatial dropout and a single-filter convolution close the network.\n",
        "\n",
        "  Args:\n",
        "    l2_strength: L2 factor applied to the kernels of every staged convolution\n",
        "      (the closing single-filter convolution is not regularised).\n",
        "\n",
        "  Returns:\n",
        "    A Keras Model compiled with Adam(3e-4), MSE loss and an RMSE metric.\n",
        "  \"\"\"\n",
        "  def conv(tensor, filters, kernel_size, padding='same'):\n",
        "    # Stride-1 convolution without bias, regularised with l2_strength.\n",
        "    return Conv2D(filters=filters, kernel_size=kernel_size, strides=[1, 1], padding=padding, use_bias=False,\n",
        "                  kernel_regularizer=tf.keras.regularizers.l2(l2_strength))(tensor)\n",
        "\n",
        "  def relu_bn(tensor):\n",
        "    # Activation first, normalisation second.\n",
        "    activated = Activation('relu')(tensor)\n",
        "    return BatchNormalization()(activated)\n",
        "\n",
        "  def stage(tensor, head_kernel, head_padding='same', residual=None):\n",
        "    # Returns the stage output together with the raw 30-filter tensor (taken\n",
        "    # before any residual addition) so that it can feed a later stage.\n",
        "    tensor = relu_bn(conv(tensor, 18, head_kernel, head_padding))\n",
        "    mid = conv(tensor, 30, [5, 1])\n",
        "    merged = mid if residual is None else mid + residual\n",
        "    tensor = relu_bn(merged)\n",
        "    tensor = relu_bn(conv(tensor, 8, [9, 1]))\n",
        "    return tensor, mid\n",
        "\n",
        "  inputs = Input(shape=[numFeatures,numSegments,1])\n",
        "\n",
        "  # Stage 1 pads the first spatial axis explicitly and then uses a 'valid' 9x8 kernel.\n",
        "  net = tf.keras.layers.ZeroPadding2D(((4,4), (0,0)))(inputs)\n",
        "  net, skip0 = stage(net, [9, 8], head_padding='valid')\n",
        "  net, skip1 = stage(net, [9, 1])\n",
        "  net, _ = stage(net, [9, 1])\n",
        "  net, _ = stage(net, [9, 1], residual=skip1)\n",
        "  net, _ = stage(net, [9, 1], residual=skip0)\n",
        "\n",
        "  # Output head.\n",
        "  net = tf.keras.layers.SpatialDropout2D(0.2)(net)\n",
        "  net = Conv2D(filters=1, kernel_size=[129,1], strides=[1, 1], padding='same')(net)\n",
        "\n",
        "  model = Model(inputs=inputs, outputs=net)\n",
        "\n",
        "  optimizer = tf.keras.optimizers.Adam(3e-4)\n",
        "  rmse_metric = tf.keras.metrics.RootMeanSquaredError('rmse')\n",
        "  model.compile(optimizer=optimizer, loss='mse', metrics=[rmse_metric])\n",
        "  return model"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "mNpHS4LuShxd"
      },
      "source": [
        "# Build the network with L2 regularisation switched off (l2_strength=0.0) and list its layers.\n",
        "model = build_model(l2_strength=0.0)\n",
        "model.summary()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8rWOgsdwglFE"
      },
      "source": [
        "# Draw the layer graph with tensor shapes.\n",
        "# You might need to install the following dependencies: sudo apt install python-pydot python-pydot-ng graphviz\n",
        "tf.keras.utils.plot_model(model, show_shapes=True, dpi=64)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7OdmgHTp4_Ou"
      },
      "source": [
        "# Open TensorBoard inline, reading event files from the ./logs directory.\n",
        "%tensorboard --logdir logs"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BucSAwkHQj8Q"
      },
      "source": [
        "# Validation loss before training; compared against after training to decide whether to save.\n",
        "baseline_val_loss = model.evaluate(test_dataset)[0]\n",
        "# The first value returned by evaluate() is the loss, not an accuracy, so label it as such.\n",
        "print(f\"Baseline validation loss {baseline_val_loss}\")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ocycFbP5-X0o"
      },
      "source": [
        "def l2_norm(vector):\n",
        "    return np.square(vector)\n",
        "\n",
        "def SDR(denoised, cleaned, eps=1e-7): # Signal to Distortion Ratio\n",
        "    a = l2_norm(denoised)\n",
        "    b = l2_norm(denoised - cleaned)\n",
        "    a_b = a / b\n",
        "    return np.mean(10 * np.log10(a_b + eps))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bcPAuMZ9SlHa"
      },
      "source": [
        "# Stop once val_loss has not improved for 50 epochs; restore_best_weights asks Keras to roll back to the best epoch.\n",
        "early_stopping_callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=50, restore_best_weights=True, baseline=None)\n",
        "\n",
        "# One TensorBoard log directory per launch, named after the current timestamp.\n",
        "logdir = os.path.join(\"logs\", datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\"))\n",
        "tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, update_freq='batch')\n",
        "# Only overwrite the checkpoint file when val_loss improves.\n",
        "checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(filepath='./denoiser_cnn_log_mel_generator.h5', \n",
        "                                                         monitor='val_loss', save_best_only=True)\n",
        "\n",
        "# train_dataset repeats forever, so steps_per_epoch defines the length of an epoch.\n",
        "model.fit(train_dataset,\n",
        "         steps_per_epoch=600, # you might need to change this\n",
        "         validation_data=test_dataset,\n",
        "         epochs=400,\n",
        "         callbacks=[early_stopping_callback, tensorboard_callback, checkpoint_callback]\n",
        "        )"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "g1ZZskDZ5L2a"
      },
      "source": [
        "# Re-evaluate after training and persist the model only if it beats the pre-training baseline.\n",
        "val_loss = model.evaluate(test_dataset)[0]\n",
        "if val_loss < baseline_val_loss:\n",
        "  model.save('./denoiser_cnn_log_mel_generator.h5')\n",
        "  # Report only after save() has returned, so the message is never shown for a failed save.\n",
        "  print(\"New model saved.\")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LeJTsGxCSuhm"
      },
      "source": [
        "## Testing"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DLiTiK8blE0n"
      },
      "source": [
        "def read_audio(filepath, sample_rate, normalize=True):\n",
        "    \"\"\"Read an audio file and return it as a numpy array\n",
        "\n",
        "    Args:\n",
        "        filepath: path of the audio file, handed to librosa.load.\n",
        "        sample_rate: value passed to librosa.load as `sr`.\n",
        "        normalize: when True, scale the signal so that its absolute peak is 1/3.\n",
        "\n",
        "    Returns:\n",
        "        (audio, sr) as returned by librosa.load, with `audio` optionally rescaled.\n",
        "    \"\"\"\n",
        "    audio, sr = librosa.load(filepath, sr=sample_rate)\n",
        "    if normalize:\n",
        "      peak = np.max(np.abs(audio)) if audio.size else 0.0\n",
        "      # A silent or empty clip has no peak to scale by: leave it untouched\n",
        "      # instead of dividing by zero (NaN signal) or failing on an empty max().\n",
        "      if peak > 0:\n",
        "        div_fac = 1 / peak / 3.0\n",
        "        audio = audio * div_fac\n",
        "    return audio, sr\n",
        "        \n",
        "def add_noise_to_clean_audio(clean_audio, noise_signal):\n",
        "    \"\"\"Adds noise to an audio sample\"\"\"\n",
        "    if len(clean_audio) >= len(noise_signal):\n",
        "        # print(\"The noisy signal is smaller than the clean audio input. Duplicating the noise.\")\n",
        "        while len(clean_audio) >= len(noise_signal):\n",
        "            noise_signal = np.append(noise_signal, noise_signal)\n",
        "\n",
        "    ## Extract a noise segment from a random location in the noise file\n",
        "    ind = np.random.randint(0, noise_signal.size - clean_audio.size)\n",
        "\n",
        "    noiseSegment = noise_signal[ind: ind + clean_audio.size]\n",
        "\n",
        "    speech_power = np.sum(clean_audio ** 2)\n",
        "    noise_power = np.sum(noiseSegment ** 2)\n",
        "    noisyAudio = clean_audio + np.sqrt(speech_power / noise_power) * noiseSegment\n",
        "    return noisyAudio\n",
        "\n",
        "def play(audio, sample_rate):\n",
        "    \"\"\"Show an inline audio player for `audio` at `sample_rate`.\"\"\"\n",
        "    player = ipd.Audio(data=audio, rate=sample_rate)\n",
        "    ipd.display(player)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "uM6ajbBFlx3b"
      },
      "source": [
        "class FeatureExtractor:\n",
        "    def __init__(self, audio, *, windowLength, overlap, sample_rate):\n",
        "        self.audio = audio\n",
        "        self.ffT_length = windowLength\n",
        "        self.window_length = windowLength\n",
        "        self.overlap = overlap\n",
        "        self.sample_rate = sample_rate\n",
        "        self.window = scipy.signal.hamming(self.window_length, sym=False)\n",
        "\n",
        "    def get_stft_spectrogram(self):\n",
        "        return librosa.stft(self.audio, n_fft=self.ffT_length, win_length=self.window_length, hop_length=self.overlap,\n",
        "                            window=self.window, center=True)\n",
        "\n",
        "    def get_audio_from_stft_spectrogram(self, stft_features):\n",
        "        return librosa.istft(stft_features, win_length=self.window_length, hop_length=self.overlap,\n",
        "                             window=self.window, center=True)\n",
        "\n",
        "    def get_mel_spectrogram(self):\n",
        "        return librosa.feature.melspectrogram(self.audio, sr=self.sample_rate, power=2.0, pad_mode='reflect',\n",
        "                                           n_fft=self.ffT_length, hop_length=self.overlap, center=True)\n",
        "\n",
        "    def get_audio_from_mel_spectrogram(self, M):\n",
        "        return librosa.feature.inverse.mel_to_audio(M, sr=self.sample_rate, n_fft=self.ffT_length, hop_length=self.overlap,\n",
        "                                             win_length=self.window_length, window=self.window,\n",
        "                                             center=True, pad_mode='reflect', power=2.0, n_iter=32, length=None)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9dcnyvquSoLs"
      },
      "source": [
        "# Load one clean speech clip at sample rate fs (peak-normalised by read_audio).\n",
        "cleanAudio, sr = read_audio(os.path.join(mozilla_basepath, 'test', 'common_voice_en_16526.mp3'), sample_rate=fs)\n",
        "print(\"Min:\", np.min(cleanAudio),\"Max:\",np.max(cleanAudio))\n",
        "ipd.Audio(data=cleanAudio, rate=sr) # play the loaded clean speech"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BaHe1okPTvV-"
      },
      "source": [
        "# Load one noise clip at sample rate fs (peak-normalised by read_audio).\n",
        "noiseAudio, sr = read_audio(os.path.join(UrbanSound8K_basepath, 'test', '7913-3-0-0.wav'), sample_rate=fs)\n",
        "print(\"Min:\", np.min(noiseAudio),\"Max:\",np.max(noiseAudio))\n",
        "ipd.Audio(data=noiseAudio, rate=sr) # play the loaded noise"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bu_eOKRfTHbp"
      },
      "source": [
        "# Magnitude STFT of the clean speech (np.abs drops the phase).\n",
        "cleanAudioFeatureExtractor = FeatureExtractor(cleanAudio, windowLength=windowLength, overlap=overlap, sample_rate=sr)\n",
        "stft_features = cleanAudioFeatureExtractor.get_stft_spectrogram()\n",
        "stft_features = np.abs(stft_features)\n",
        "print(\"Min:\", np.min(stft_features),\"Max:\",np.max(stft_features))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sHWcmobyTP4E"
      },
      "source": [
        "# Mix the noise clip into the clean speech.\n",
        "noisyAudio = add_noise_to_clean_audio(cleanAudio, noiseAudio)\n",
        "ipd.Audio(data=noisyAudio, rate=fs) # play the noisy mixture"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "75M29dl3bBeF"
      },
      "source": [
        "def prepare_input_features(stft_features):\n",
        "    # Phase Aware Scaling: To avoid extreme differences (more than\n",
        "    # 45 degree) between the noisy and clean phase, the clean spectral magnitude was encoded as similar to [21]:\n",
        "    noisySTFT = np.concatenate([stft_features[:,0:numSegments-1], stft_features], axis=1)\n",
        "    stftSegments = np.zeros((numFeatures, numSegments , noisySTFT.shape[1] - numSegments + 1))\n",
        "\n",
        "    for index in range(noisySTFT.shape[1] - numSegments + 1):\n",
        "        stftSegments[:,:,index] = noisySTFT[:,index:index + numSegments]\n",
        "    return stftSegments"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9cjO5-cjTP6t"
      },
      "source": [
        "# STFT of the noisy mixture.\n",
        "noiseAudioFeatureExtractor = FeatureExtractor(noisyAudio, windowLength=windowLength, overlap=overlap, sample_rate=sr)\n",
        "noise_stft_features = noiseAudioFeatureExtractor.get_stft_spectrogram()\n",
        "\n",
        "# Paper: Besides, spectral phase was not used in the training phase.\n",
        "# At reconstruction, noisy spectral phase was used instead to\n",
        "# perform inverse STFT and recover human speech.\n",
        "noisyPhase = np.angle(noise_stft_features)\n",
        "print(noisyPhase.shape)\n",
        "noise_stft_features = np.abs(noise_stft_features)\n",
        "\n",
        "# Standardise the magnitudes with the mean/std of this clip; both values are kept for later use.\n",
        "mean = np.mean(noise_stft_features)\n",
        "std = np.std(noise_stft_features)\n",
        "noise_stft_features = (noise_stft_features - mean) / std"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3p5PWWkrlE3m"
      },
      "source": [
        "# Stack sliding windows of numSegments frames from the standardised noisy magnitudes.\n",
        "predictors = prepare_input_features(noise_stft_features)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4pSoOw2fTP9N"
      },
      "source": [
        "# (features, segments, windows) -> (windows, features, segments, 1): windows first, with a trailing singleton axis.\n",
        "predictors = np.reshape(predictors, (predictors.shape[0], predictors.shape[1], 1, predictors.shape[2]))\n",
        "predictors = np.transpose(predictors, (3, 0, 1, 2)).astype(np.float32)\n",
        "print('predictors.shape:', predictors.shape)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "a5sNR4sSTP_1"
      },
      "source": [
        "# Run the model on the prepared windows and show the shape of its output.\n",
        "STFTFullyConvolutional = model.predict(predictors)\n",
        "print(STFTFullyConvolutional.shape)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LzCga3PdUVwG"
      },
      "source": [
        "def revert_features_to_audio(features, phase, cleanMean=None, cleanStd=None):\n",
        "    # scale the outpus back to the original range\n",
        "    if cleanMean and cleanStd:\n",
        "        features = cleanStd * features + cleanMean\n",
        "\n",
        "    phase = np.transpose(phase, (1, 0))\n",
        "    features = np.squeeze(features)\n",
        "\n",
        "    # features = librosa.db_to_power(features)\n",
        "    features = features * np.exp(1j * phase)  # that fixes the abs() ope previously done\n",
        "\n",
        "    features = np.transpose(features, (1, 0))\n",
        "    return noiseAudioFeatureExtractor.get_audio_from_stft_spectrogram(features)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LWlUDtPzURlQ"
      },
      "source": [
        "# Rebuild the waveform from the prediction, the noisy phase and the mean/std of the noisy clip.\n",
        "denoisedAudioFullyConvolutional = revert_features_to_audio(STFTFullyConvolutional, noisyPhase, mean, std)\n",
        "print(\"Min:\", np.min(denoisedAudioFullyConvolutional),\"Max:\",np.max(denoisedAudioFullyConvolutional))\n",
        "ipd.Audio(data=denoisedAudioFullyConvolutional, rate=fs) # play the denoised result"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VvEIBy7EHgeP"
      },
      "source": [
        "# A numeric identifier of the sound class -- Types of noise\n",
        "# 0 = air_conditioner\n",
        "# 1 = car_horn\n",
        "# 2 = children_playing\n",
        "# 3 = dog_bark\n",
        "# 4 = drilling\n",
        "# 5 = engine_idling\n",
        "# 6 = gun_shot\n",
        "# 7 = jackhammer\n",
        "# 8 = siren\n",
        "# 9 = street_music"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tv_7ZwWaUW0_"
      },
      "source": [
        "# Waveforms of the clean, noisy and denoised signals, stacked on a shared y-axis.\n",
        "f, (ax1, ax2, ax3) = plt.subplots(3, 1, sharey=True)\n",
        "\n",
        "ax1.plot(cleanAudio)\n",
        "ax1.set_title(\"Clean Audio\")\n",
        "\n",
        "ax2.plot(noisyAudio)\n",
        "ax2.set_title(\"Noisy Audio\")\n",
        "\n",
        "ax3.plot(denoisedAudioFullyConvolutional)\n",
        "ax3.set_title(\"Denoised Audio\")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8LM7E91avKA3"
      },
      "source": [
        ""
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}
